{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 导入必要的库"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-13T11:38:00.577786Z",
     "start_time": "2018-07-13T11:38:00.567386Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# os.environ['CUDA_VISIBLE_DEVICES'] = '4'\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "import random\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.models as models\n",
    "import shutil\n",
    "from glob import glob\n",
    "\n",
    "from tensorboardX import SummaryWriter\n",
    "\n",
    "import numpy as np\n",
    "import multiprocessing\n",
    "\n",
    "import copy\n",
    "from tqdm import tqdm\n",
    "from collections import defaultdict\n",
    "\n",
    "import horovod.torch as hvd\n",
    "import torch.utils.data.distributed\n",
    "\n",
    "from utils import *\n",
    "from models import *\n",
    "import time\n",
    "\n",
    "from pprint import pprint\n",
    "# Route the `display(...)` calls used later in this notebook through pprint.\n",
    "display = pprint\n",
    "\n",
    "# Horovod must be initialised before hvd.local_rank() / hvd.rank() are queried.\n",
    "hvd.init()\n",
    "\n",
    "# Pin this process to the GPU matching its Horovod local rank. The CUDA call is only made\n",
    "# when CUDA is available: calling torch.cuda.set_device unconditionally fails on CPU-only\n",
    "# machines, which defeated the CPU fallback of the device selection.\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.set_device(hvd.local_rank())\n",
    "    device = torch.device('cuda:%s' % hvd.local_rank())\n",
    "else:\n",
    "    device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-13T11:38:00.744883Z",
     "start_time": "2018-07-13T11:38:00.737156Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model_name: metanet_base32_style25_tv1e-07_l21e-05_taghvd, rank: 0\n"
     ]
    }
   ],
   "source": [
    "# --- Run configuration ---\n",
    "is_hvd = False  # switches the Horovod-specific branches in the cells below on or off\n",
    "tag = 'nohvd'   # suffix embedded in model_name\n",
    "\n",
    "# Size argument handed to TransformNet, and the weights of the three loss terms.\n",
    "base = 32\n",
    "content_weight = 1\n",
    "style_weight = 50\n",
    "tv_weight = 1e-6\n",
    "\n",
    "# Training length, content batch size and the side length of the random crops.\n",
    "epochs = 22\n",
    "batch_size = 8\n",
    "width = 256\n",
    "\n",
    "# How often (in batches) histograms and preview images are written to TensorBoard.\n",
    "verbose_hist_batch, verbose_image_batch = 100, 800\n",
    "\n",
    "# The run name is used for the TensorBoard directory and the checkpoint file names.\n",
    "model_name = f'metanet_base{base}_style{style_weight}_tv{tv_weight}_tag{tag}'\n",
    "print(f'model_name: {model_name}, rank: {hvd.rank()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-12T13:04:56.541964Z",
     "start_time": "2018-07-12T13:04:56.535774Z"
    }
   },
   "outputs": [],
   "source": [
    "def rmrf(path):\n",
    "    \"\"\"Best-effort recursive delete of the directory tree at `path`.\n",
    "\n",
    "    Errors raised while removing the tree (for example because `path` does not\n",
    "    exist) are ignored, so the call is safe when there is nothing to clean up.\n",
    "    Unlike the previous bare `except:`, KeyboardInterrupt and SystemExit are\n",
    "    no longer swallowed.\n",
    "\n",
    "    Returns None.\n",
    "    \"\"\"\n",
    "    shutil.rmtree(path, ignore_errors=True)\n",
    "\n",
    "# Remove stray .AppleDouble folders and any previous TensorBoard data for this run.\n",
    "# Only rank 0 (or the single non-Horovod process) touches runs/, using the same rank guard\n",
    "# as the cell that creates the SummaryWriter; otherwise every Horovod process would delete\n",
    "# the shared run directory.\n",
    "if not is_hvd or hvd.rank() == 0:\n",
    "    for f in glob('runs/*/.AppleDouble'):\n",
    "        rmrf(f)\n",
    "\n",
    "    rmrf('runs/' + model_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 搭建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-13T10:38:51.871437Z",
     "start_time": "2018-07-13T10:38:43.789881Z"
    }
   },
   "outputs": [],
   "source": [
    "# Loss network: the first 23 layers of torchvision's pretrained VGG-16 feature stack,\n",
    "# wrapped in the project's VGG module, moved to the target device and put in eval mode.\n",
    "vgg16 = VGG(models.vgg16(pretrained=True).features[:23]).to(device).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-13T10:38:51.925705Z",
     "start_time": "2018-07-13T10:38:51.874457Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(int,\n",
       "            {'downsampling.5': 18496,\n",
       "             'downsampling.9': 73856,\n",
       "             'residuals.0.conv.1': 147584,\n",
       "             'residuals.0.conv.5': 147584,\n",
       "             'residuals.1.conv.1': 147584,\n",
       "             'residuals.1.conv.5': 147584,\n",
       "             'residuals.2.conv.1': 147584,\n",
       "             'residuals.2.conv.5': 147584,\n",
       "             'residuals.3.conv.1': 147584,\n",
       "             'residuals.3.conv.5': 147584,\n",
       "             'residuals.4.conv.1': 147584,\n",
       "             'residuals.4.conv.5': 147584,\n",
       "             'upsampling.2': 73792,\n",
       "             'upsampling.7': 18464})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the image transformation network and move it to the target device.\n",
    "transform_net = TransformNet(base)\n",
    "transform_net = transform_net.to(device)\n",
    "\n",
    "# Last expression of the cell: show the mapping returned by get_param_dict().\n",
    "transform_net.get_param_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-13T10:38:54.307510Z",
     "start_time": "2018-07-13T10:38:51.954926Z"
    },
    "code_folding": [],
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# The meta network is constructed from transform_net's parameter dict, then moved to the device.\n",
    "metanet = MetaNet(transform_net.get_param_dict())\n",
    "metanet = metanet.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 载入数据集\n",
    "\n",
    "> During training, each content image or style image is resized to keep the smallest dimension in the range [256, 480], and randomly cropped regions of size 256 × 256.\n",
    "\n",
    "## 载入 COCO 数据集和 WikiArt 数据集\n",
    "\n",
    "> The batch size of content images is 8 and the meta network is trained for 20 iterations before changing the style image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-13T11:38:09.383610Z",
     "start_time": "2018-07-13T11:38:08.371037Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset ImageFolder\n",
      "    Number of datapoints: 23806\n",
      "    Root Location: /home/ypw/WikiArt/\n",
      "    Transforms (if any): Compose(\n",
      "                             RandomResizedCrop(size=(256, 256), scale=(0.5333, 1), ratio=(1, 1), interpolation=PIL.Image.BILINEAR)\n",
      "                             ToTensor()\n",
      "                             Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
      "                         )\n",
      "    Target Transforms (if any): None\n",
      "--------------------\n",
      "Dataset ImageFolder\n",
      "    Number of datapoints: 164062\n",
      "    Root Location: /home/ypw/COCO/\n",
      "    Transforms (if any): Compose(\n",
      "                             RandomResizedCrop(size=(256, 256), scale=(0.5333, 1), ratio=(1, 1), interpolation=PIL.Image.BILINEAR)\n",
      "                             ToTensor()\n",
      "                             Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
      "                         )\n",
      "    Target Transforms (if any): None\n"
     ]
    }
   ],
   "source": [
    "# Preprocessing shared by both datasets: random square crop resized to `width`,\n",
    "# conversion to a tensor, then the project's tensor_normalizer.\n",
    "data_transform = transforms.Compose([\n",
    "    transforms.RandomResizedCrop(width, scale=(256/480, 1), ratio=(1, 1)),\n",
    "    transforms.ToTensor(),\n",
    "    tensor_normalizer,\n",
    "])\n",
    "\n",
    "style_dataset = torchvision.datasets.ImageFolder('/home/ypw/WikiArt/', transform=data_transform)\n",
    "content_dataset = torchvision.datasets.ImageFolder('/home/ypw/COCO/', transform=data_transform)\n",
    "\n",
    "# Under Horovod each rank reads its own shard through a DistributedSampler;\n",
    "# otherwise the loader simply shuffles the whole content dataset.\n",
    "if is_hvd:\n",
    "    train_sampler = torch.utils.data.distributed.DistributedSampler(\n",
    "        content_dataset, num_replicas=hvd.size(), rank=hvd.rank())\n",
    "    loader_kwargs = {'sampler': train_sampler}\n",
    "else:\n",
    "    loader_kwargs = {'shuffle': True}\n",
    "\n",
    "content_data_loader = torch.utils.data.DataLoader(\n",
    "    content_dataset,\n",
    "    batch_size=batch_size,\n",
    "    num_workers=multiprocessing.cpu_count(),\n",
    "    **loader_kwargs)\n",
    "\n",
    "# Print the dataset summaries once (rank 0 only when running under Horovod).\n",
    "if not is_hvd or hvd.rank() == 0:\n",
    "    print(style_dataset, '-'*20, content_dataset, sep='\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 测试 infer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-13T09:42:47.537476Z",
     "start_time": "2018-07-13T09:42:47.379446Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "features:\n",
      "[torch.Size([4, 64, 256, 256]),\n",
      " torch.Size([4, 128, 128, 128]),\n",
      " torch.Size([4, 256, 64, 64]),\n",
      " torch.Size([4, 512, 32, 32])]\n",
      "weights:\n",
      "[torch.Size([4, 18496]),\n",
      " torch.Size([4, 73856]),\n",
      " torch.Size([4, 147584]),\n",
      " torch.Size([4, 147584]),\n",
      " torch.Size([4, 147584]),\n",
      " torch.Size([4, 147584]),\n",
      " torch.Size([4, 147584]),\n",
      " torch.Size([4, 147584]),\n",
      " torch.Size([4, 147584]),\n",
      " torch.Size([4, 147584]),\n",
      " torch.Size([4, 147584]),\n",
      " torch.Size([4, 147584]),\n",
      " torch.Size([4, 73792]),\n",
      " torch.Size([4, 18464])]\n",
      "transformed_images:\n",
      "torch.Size([4, 3, 256, 256])\n"
     ]
    }
   ],
   "source": [
    "# Smoke test of the inference path on random inputs: VGG features -> meta network ->\n",
    "# generated weights -> transformation network.\n",
    "metanet.eval()\n",
    "transform_net.eval()\n",
    "\n",
    "# Forward-only check, so run it under no_grad: nothing here is back-propagated and\n",
    "# recording an autograd graph for every generated weight tensor only costs memory.\n",
    "with torch.no_grad():\n",
    "    rands = torch.rand(4, 3, 256, 256).to(device)\n",
    "    features = vgg16(rands)\n",
    "    weights = metanet(mean_std(features))\n",
    "    transform_net.set_weights(weights)\n",
    "    transformed_images = transform_net(torch.rand(4, 3, 256, 256).to(device))\n",
    "\n",
    "# Report the shapes once (rank 0 only when running under Horovod).\n",
    "if not is_hvd or hvd.rank() == 0:\n",
    "    print('features:')\n",
    "    display([x.shape for x in features])\n",
    "\n",
    "    print('weights:')\n",
    "    display([x.shape for x in weights.values()])\n",
    "\n",
    "    print('transformed_images:')\n",
    "    display(transformed_images.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 初始化一些变量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-12T13:05:07.481869Z",
     "start_time": "2018-07-12T13:05:07.398188Z"
    }
   },
   "outputs": [],
   "source": [
    "# Pick one random style image (with a leading batch dimension added) and four random\n",
    "# content images; [0] takes the first element of each dataset sample.\n",
    "visualization_style_image = random.choice(style_dataset)[0][None].to(device)\n",
    "visualization_content_images = torch.stack(\n",
    "    [random.choice(content_dataset)[0] for _ in range(4)]\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-12T13:05:08.288833Z",
     "start_time": "2018-07-12T13:05:07.483858Z"
    }
   },
   "outputs": [],
   "source": [
    "# Rank 0 (or the single non-Horovod process) logs into runs/ after clearing old data;\n",
    "# every other rank gets a writer under /tmp so the writer.* calls below work on all ranks.\n",
    "if is_hvd and hvd.rank() != 0:\n",
    "    log_dir = '/tmp/' + model_name\n",
    "else:\n",
    "    for stale_dir in glob('runs/*/.AppleDouble'):\n",
    "        rmrf(stale_dir)\n",
    "\n",
    "    rmrf('runs/' + model_name)\n",
    "    log_dir = 'runs/' + model_name\n",
    "\n",
    "writer = SummaryWriter(log_dir)\n",
    "\n",
    "# Sample the images used for the periodic previews written during training.\n",
    "visualization_style_image = random.choice(style_dataset)[0].unsqueeze(0).to(device)\n",
    "visualization_content_images = torch.stack(\n",
    "    [random.choice(content_dataset)[0] for _ in range(4)]).to(device)\n",
    "\n",
    "# Log the content previews at step 0 and the transformation network's graph.\n",
    "writer.add_image('content_image', recover_tensor(visualization_content_images), 0)\n",
    "writer.add_graph(transform_net, (rands, ))\n",
    "\n",
    "# Drop the tensors left over from the inference smoke test.\n",
    "del rands, features, weights, transformed_images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-12T13:05:08.334236Z",
     "start_time": "2018-07-12T13:05:08.329306Z"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Gather every parameter that requires gradients from the three networks, keyed by its\n",
    "# parameter name; a name that appears in more than one network keeps the last one seen.\n",
    "trainable_params = {\n",
    "    name: param\n",
    "    for model in [vgg16, transform_net, metanet]\n",
    "    for name, param in model.named_parameters()\n",
    "    if param.requires_grad\n",
    "}\n",
    "\n",
    "# Shapes of the same parameters, under the same keys and in the same order.\n",
    "trainable_param_shapes = {name: param.shape for name, param in trainable_params.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 开始训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-12T13:05:09.472661Z",
     "start_time": "2018-07-12T13:05:08.337215Z"
    }
   },
   "outputs": [],
   "source": [
    "# Adam over all collected trainable parameters, learning rate 1e-3.\n",
    "optimizer = optim.Adam(trainable_params.values(), lr=1e-3)\n",
    "\n",
    "if is_hvd:\n",
    "    # Wrap the optimizer for Horovod, then broadcast rank 0's transform_net and metanet\n",
    "    # state to all ranks.\n",
    "    optimizer = hvd.DistributedOptimizer(\n",
    "        optimizer, named_parameters=trainable_params.items())\n",
    "    params = {**transform_net.state_dict(), **metanet.state_dict()}\n",
    "    hvd.broadcast_parameters(params, root_rank=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-12T13:06:43.549811Z",
     "start_time": "2018-07-12T13:05:09.476595Z"
    },
    "code_folding": [],
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "n_batch = len(content_data_loader)\n",
    "metanet.train()\n",
    "transform_net.train()\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    # Under Horovod, tell the DistributedSampler which epoch this is; without this call\n",
    "    # it yields the content images in the same order every epoch.\n",
    "    if is_hvd:\n",
    "        train_sampler.set_epoch(epoch)\n",
    "\n",
    "    smoother = defaultdict(Smooth)\n",
    "    with tqdm(enumerate(content_data_loader), total=n_batch) as pbar:\n",
    "        for batch, (content_images, _) in pbar:\n",
    "            n_iter = epoch*n_batch + batch\n",
    "\n",
    "            # Every 20 batches, pick a new random style image and compute its features.\n",
    "            if batch % 20 == 0:\n",
    "                style_image = random.choice(style_dataset)[0].unsqueeze(0).to(device)\n",
    "                style_features = vgg16(style_image)\n",
    "                style_mean_std = mean_std(style_features)\n",
    "\n",
    "            # Skip the batch if any image channel is one solid colour (min == max over H and W).\n",
    "            # NOTE(review): under Horovod a rank that skips here takes fewer optimizer steps\n",
    "            # than its peers - confirm this cannot desynchronise the ranks.\n",
    "            x = content_images.cpu().numpy()\n",
    "            if (x.min(-1).min(-1) == x.max(-1).max(-1)).any():\n",
    "                continue\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # Generate the transformation network's weights from the style image.\n",
    "            weights = metanet(mean_std(style_features))\n",
    "            transform_net.set_weights(weights, 0)\n",
    "\n",
    "            # Stylise the content batch with the generated weights.\n",
    "            content_images = content_images.to(device)\n",
    "            transformed_images = transform_net(content_images)\n",
    "\n",
    "            # VGG features of the content batch and of the stylised batch.\n",
    "            content_features = vgg16(content_images)\n",
    "            transformed_features = vgg16(transformed_images)\n",
    "            transformed_mean_std = mean_std(transformed_features)\n",
    "\n",
    "            # Content loss: MSE between the third returned VGG feature maps.\n",
    "            content_loss = content_weight * F.mse_loss(transformed_features[2], content_features[2])\n",
    "\n",
    "            # Style loss: MSE between feature statistics of the stylised batch and the style image.\n",
    "            style_loss = style_weight * F.mse_loss(transformed_mean_std,\n",
    "                                                   style_mean_std.expand_as(transformed_mean_std))\n",
    "\n",
    "            # Total variation loss: summed absolute differences between neighbouring pixels\n",
    "            # along the last two axes.\n",
    "            y = transformed_images\n",
    "            tv_loss = tv_weight * (torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) +\n",
    "                                    torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])))\n",
    "\n",
    "            loss = content_loss + style_loss + tv_loss\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            # Convert to plain Python floats once; the running averages and the TensorBoard\n",
    "            # scalars then no longer hold references to graph-attached tensors.\n",
    "            content_loss_value = content_loss.item()\n",
    "            style_loss_value = style_loss.item()\n",
    "            tv_loss_value = tv_loss.item()\n",
    "            loss_value = loss.item()\n",
    "\n",
    "            smoother['content_loss'] += content_loss_value\n",
    "            smoother['style_loss'] += style_loss_value\n",
    "            smoother['tv_loss'] += tv_loss_value\n",
    "            smoother['loss'] += loss_value\n",
    "\n",
    "            # Largest generated weight value, logged to watch for blow-ups.\n",
    "            max_value = max([x.max().item() for x in weights.values()])\n",
    "\n",
    "            writer.add_scalar('loss/loss', loss_value, n_iter)\n",
    "            writer.add_scalar('loss/content_loss', content_loss_value, n_iter)\n",
    "            writer.add_scalar('loss/style_loss', style_loss_value, n_iter)\n",
    "            writer.add_scalar('loss/total_variation', tv_loss_value, n_iter)\n",
    "            writer.add_scalar('loss/max', max_value, n_iter)\n",
    "\n",
    "            s = 'Epoch: {} '.format(epoch+1)\n",
    "            s += 'Content: {:.2f} '.format(smoother['content_loss'])\n",
    "            s += 'Style: {:.1f} '.format(smoother['style_loss'])\n",
    "            s += 'Loss: {:.2f} '.format(smoother['loss'])\n",
    "            s += 'Max: {:.2f}'.format(max_value)\n",
    "\n",
    "            if (batch + 1) % verbose_image_batch == 0:\n",
    "                transform_net.eval()\n",
    "                # Preview only: no gradients are needed, so skip building an autograd graph.\n",
    "                with torch.no_grad():\n",
    "                    visualization_transformed_images = transform_net(visualization_content_images)\n",
    "                transform_net.train()\n",
    "                visualization_transformed_images = torch.cat([style_image, visualization_transformed_images])\n",
    "                writer.add_image('debug', recover_tensor(visualization_transformed_images), n_iter)\n",
    "                del visualization_transformed_images\n",
    "\n",
    "            if (batch + 1) % verbose_hist_batch == 0:\n",
    "                # Histograms of the generated weights and of both networks' own parameters.\n",
    "                for name, param in weights.items():\n",
    "                    writer.add_histogram('transform_net.'+name, param.detach().cpu().numpy(),\n",
    "                                         n_iter, bins='auto')\n",
    "\n",
    "                for name, param in transform_net.named_parameters():\n",
    "                    writer.add_histogram('transform_net.'+name, param.detach().cpu().numpy(),\n",
    "                                         n_iter, bins='auto')\n",
    "\n",
    "                for name, param in metanet.named_parameters():\n",
    "                    # Drop the final name component so tensors of one layer share a tag.\n",
    "                    # A slice removes the last component itself; list.remove(l[-1]) would\n",
    "                    # remove the first component equal to it.\n",
    "                    layer_name = '.'.join(name.split('.')[:-1])\n",
    "                    writer.add_histogram('metanet.'+layer_name, param.detach().cpu().numpy(),\n",
    "                                         n_iter, bins='auto')\n",
    "\n",
    "            pbar.set_description(s)\n",
    "\n",
    "            del transformed_images, weights\n",
    "\n",
    "    if not is_hvd or hvd.rank() == 0:\n",
    "        # Create the output directories up front so a missing folder cannot make\n",
    "        # torch.save fail after a full epoch of training.\n",
    "        os.makedirs('checkpoints', exist_ok=True)\n",
    "        os.makedirs('models', exist_ok=True)\n",
    "\n",
    "        # Per-epoch checkpoints.\n",
    "        torch.save(metanet.state_dict(), 'checkpoints/{}_{}.pth'.format(model_name, epoch+1))\n",
    "        torch.save(transform_net.state_dict(),\n",
    "                   'checkpoints/{}_transform_net_{}.pth'.format(model_name, epoch+1))\n",
    "\n",
    "        # Latest weights, overwritten every epoch.\n",
    "        torch.save(metanet.state_dict(), 'models/{}.pth'.format(model_name))\n",
    "        torch.save(transform_net.state_dict(), 'models/{}_transform_net.pth'.format(model_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": "block",
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": "40",
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "position": {
    "height": "441px",
    "left": "934px",
    "right": "20px",
    "top": "120px",
    "width": "333px"
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
